#!/usr/bin/env python
#
# Convert LRG data into UTA txinfo and exonset files (lrg.txinfo, lrg.exonset)
#
# LRG provides the following downloads:
# zip file of all XML files - includes TX coords in TX space -> use for tx_info
# zip file of all fasta files
# tsv file of summary data w/ genomic coords -> use for exon_set
#
# http://www.lrg-sequence.org
#
from __future__ import with_statement
import argparse
import operator
import csv
import operator
import os
import xml.etree.ElementTree

from bioutils.accessions import chr_to_NC

import uta.formats.exonset
import uta.formats.txinfo


def _genomic_exon_string(exons_coords, strand):
    """Convert an LRG "start-end,start-end" list to an interbase exon string.

    Each start is decremented by one (interbase conversion); the exon order
    is reversed for strand "-1".  Exons are joined as "start,end;start,end".
    """
    spans = []
    for pair in exons_coords.split(","):
        start, end = pair.split("-")
        spans.append("{},{}".format(int(start) - 1, int(end)))
    if strand == "-1":
        spans.reverse()
    return ";".join(spans)


def _exon_coordinates(exon, coord_system):
    """Return the <coordinates> child of *exon* for *coord_system*, or None."""
    # The predicate value is single-quoted so it can sit inside a
    # double-quoted Python literal.
    return exon.find(
        "./coordinates[@coord_system='{}']".format(coord_system))


def _find_transcript(root, tx_name):
    """Return the fixed-annotation <transcript> named *tx_name*, or None."""
    for tx_element in root.findall("./fixed_annotation/transcript"):
        if tx_element.attrib["name"] == tx_name:
            return tx_element
    return None


def _cds_bounds_in_tx(tx_element, exons, lrg_root, tx_ac):
    """Map the CDS start/end from LRG coordinates to interbase tx coordinates.

    Returns a (cds_start_i, cds_end_i) tuple, or None when the transcript has
    no coding_region/coordinates element or the CDS start and/or end does not
    fall inside any exon.
    """
    cds_tag = tx_element.find("./coding_region/coordinates")
    if cds_tag is None:
        return None
    cds_start_lrg = int(cds_tag.attrib["start"])
    cds_end_lrg = int(cds_tag.attrib["end"])

    cds_start_i = cds_end_i = None
    for exon in exons:
        lrg_coords = _exon_coordinates(exon, lrg_root)
        exon_lrg_start = int(lrg_coords.attrib["start"])
        exon_lrg_end = int(lrg_coords.attrib["end"])
        exon_tx_start = int(_exon_coordinates(exon, tx_ac).attrib["start"])

        # offset within the exon + exon start in tx space; the extra -1 on
        # the start converts to interbase
        if exon_lrg_start <= cds_start_lrg <= exon_lrg_end:
            cds_start_i = cds_start_lrg - exon_lrg_start + exon_tx_start - 1
        if exon_lrg_start <= cds_end_lrg <= exon_lrg_end:
            cds_end_i = cds_end_lrg - exon_lrg_start + exon_tx_start

        if cds_start_i is not None and cds_end_i is not None:
            break

    if cds_start_i is None or cds_end_i is None:
        return None
    return cds_start_i, cds_end_i


def _tx_exon_string(exons, tx_ac):
    """Return the exons' tx-space coordinates as a sorted interbase string."""
    spans = []
    for exon in exons:
        coords = _exon_coordinates(exon, tx_ac)
        # convert to interbase (subtract 1 from start)
        spans.append((int(coords.attrib["start"]) - 1,
                      int(coords.attrib["end"])))
    spans.sort(key=operator.itemgetter(0))
    return ";".join("{},{}".format(start, end) for (start, end) in spans)


def main():
    """Convert LRG downloads into lrg.exonset and lrg.txinfo files.

    Command-line arguments:
      -t  tab-delimited LRG summary file (genomic coords; source of exon sets)
      -x  directory of LRG XML files (tx-space coords; source of tx info)
      -o  output directory; created if it does not exist

    Transcripts whose XML file or <transcript> element is missing are
    reported with a WARNING on stdout and skipped; transcripts whose CDS
    cannot be located within their exons are reported with an ERROR and
    skipped.
    """
    parser = argparse.ArgumentParser(
        description="convert LRG files to format to import to UTA")
    # All three options are needed by the code below; requiring them turns a
    # TypeError traceback on None into a normal argparse usage error.
    parser.add_argument("-t", action="store", dest="tabfile", required=True,
                        help="input tab-delimited text file from LRG")
    parser.add_argument("-x", action="store", dest="xmldir", required=True,
                        help="input XML dir from LRG")
    parser.add_argument("-o", action="store", dest="outdir", required=True,
                        help="directory to write output tx and exonset files")

    args_in = parser.parse_args()

    if not os.path.exists(args_in.outdir):
        os.makedirs(args_in.outdir)

    tx_acs_hgnc = {}    # tx accession -> HGNC symbol; reused for the TX file
    exonset_lines = set()   # a set, so duplicate rows collapse

    # read exon sets from the summary file
    with open(args_in.tabfile, "r") as f:
        f.readline()   # discard the 1st line - timestamp
        for row in csv.DictReader(f, delimiter="\t"):
            tx_ac = row["# LRG_TRANSCRIPT"]
            exonset_lines.add((tx_ac,
                               chr_to_NC[row["CHROMOSOME"]],
                               "LRG",
                               row["STRAND"],
                               _genomic_exon_string(row["EXONS_COORDS"],
                                                    row["STRAND"])))
            tx_acs_hgnc[tx_ac] = row["HGNC_SYMBOL"]

    # write exon set file
    with open(os.path.join(args_in.outdir, "lrg.exonset"), "w") as f:
        esw = uta.formats.exonset.ExonSetWriter(f)
        # Sorting whole tuples keeps the transcript accession as the primary
        # key and makes the order of ties deterministic.
        for exonset_line in sorted(exonset_lines):
            esw.write(uta.formats.exonset.ExonSet(*exonset_line))

    # write tx info file
    with open(os.path.join(args_in.outdir, "lrg.txinfo"), "w") as f_out:
        tw = uta.formats.txinfo.TxInfoWriter(f_out)

        for tx_ac in sorted(tx_acs_hgnc):
            # TX are LRG_7t1; XML file name is LRG_7.xml
            split_at = tx_ac.find("t")
            lrg_root = tx_ac[:split_at]
            tx_name = tx_ac[split_at:]
            xml_filepath = os.path.join(
                args_in.xmldir, "{}.xml".format(lrg_root))

            if not os.path.exists(xml_filepath):
                print("WARNING: xml file {} not found".format(xml_filepath))
                continue

            root = xml.etree.ElementTree.parse(xml_filepath).getroot()
            tx_element = _find_transcript(root, tx_name)
            if tx_element is None:
                print("WARNING: tx {} not found for accession {}".format(
                    tx_name, tx_ac))
                continue

            exons = tx_element.findall("./exon")
            cds_bounds = _cds_bounds_in_tx(tx_element, exons, lrg_root, tx_ac)
            if cds_bounds is None:
                # should never happen for a coding transcript
                print("ERROR: failed to find cds start and/or end for "
                      "accession {}".format(tx_ac))
                continue

            tx_cds_se_i = "{},{}".format(*cds_bounds)
            tx_exon_se_i = _tx_exon_string(exons, tx_ac)

            # finally - output result
            tw.write(uta.formats.txinfo.TxInfo(
                "LRG", tx_ac, tx_acs_hgnc[tx_ac], tx_cds_se_i, tx_exon_se_i))


# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
